{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d658f909-e679-41e9-9c4e-e0241c719049",
   "metadata": {},
   "source": [
    "If you're not running in Saturn Cloud, you need to install these libraries.\n",
    "\n",
    "Make sure you use the latest versions:\n",
    "\n",
    "```\n",
    "pip install -U transformers accelerate bitsandbytes\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6744a00f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# HF_HOME is the Hugging Face cache root; it is set here, before transformers is imported in a later cell.\n",
    "os.environ['HF_HOME'] = '/run/cache/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "506fab2a-a50c-42bd-a106-c83a9d2828ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-06-13 19:36:42--  https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 3832 (3.7K) [text/plain]\n",
      "Saving to: ‘minsearch.py’\n",
      "\n",
      "minsearch.py        100%[===================>]   3.74K  --.-KB/s    in 0s      \n",
      "\n",
      "2024-06-13 19:36:42 (50.3 MB/s) - ‘minsearch.py’ saved [3832/3832]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Remove any existing minsearch.py, then download the file from the minsearch repository's main branch.\n",
    "!rm -f minsearch.py\n",
    "!wget https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ac947de-effd-4b61-8792-a6d7a133f347",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<minsearch.Index at 0x7f58c573f970>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import requests \n",
    "import minsearch\n",
    "\n",
    "# Download the FAQ documents, grouped by course, as JSON.\n",
    "docs_url = 'https://github.com/DataTalksClub/llm-zoomcamp/blob/main/01-intro/documents.json?raw=1'\n",
    "docs_response = requests.get(docs_url)\n",
    "# Surface HTTP failures here instead of as a confusing JSON decode error on the next line.\n",
    "docs_response.raise_for_status()\n",
    "documents_raw = docs_response.json()\n",
    "\n",
    "# Flatten the per-course groups into one list, tagging each document with its course name.\n",
    "# Copies are made so the downloaded payload in `documents_raw` is left untouched.\n",
    "documents = [\n",
    "    {**doc, 'course': course['course']}\n",
    "    for course in documents_raw\n",
    "    for doc in course['documents']\n",
    "]\n",
    "\n",
    "# Free-text fields are searched; `course` is an exact-match keyword field used for filtering.\n",
    "index = minsearch.Index(\n",
    "    text_fields=[\"question\", \"text\", \"section\"],\n",
    "    keyword_fields=[\"course\"]\n",
    ")\n",
    "\n",
    "index.fit(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8f087272-b44d-4738-9ea2-175ec63a058b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def search(query):\n",
    "    boost = {'question': 3.0, 'section': 0.5}\n",
    "\n",
    "    results = index.search(\n",
    "        query=query,\n",
    "        filter_dict={'course': 'data-engineering-zoomcamp'},\n",
    "        boost_dict=boost,\n",
    "        num_results=5\n",
    "    )\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe8bff3e-b672-42be-866b-f2d9bb217106",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rag(query):\n",
    "    search_results = search(query)\n",
    "    prompt = build_prompt(query, search_results)\n",
    "    answer = llm(prompt)\n",
    "    return answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "def281ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f5891365750>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "\n",
    "# Seed torch's RNG with 0; the call returns the Generator object, which is displayed as this cell's output.\n",
    "torch.random.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1a2c29d5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "848f3d8a68b84dc79caf6b612bd559fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/3.55k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f27e92cbc2441df95abd80fbc835d62",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "configuration_phi3.py:   0%|          | 0.00/10.4k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:\n",
      "- configuration_phi3.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36d6864e25704d519b8ca1583c49d28a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modeling_phi3.py:   0%|          | 0.00/73.8k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:\n",
      "- modeling_phi3.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
      "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56e94e8225d1442ca876097ee8ff198e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/16.3k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8b8f7ff372a4562b2696bdb104bc512",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdda6929599d4c2789ee39b5e630f843",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "145ded2da5a0481ba4ddaad65af2afb6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c88be68816c54e17a84c168b5ffcce6c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11d66a9f1b8f4f2e80e911cb52383cc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/172 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f556615fb7c64f5cb764e3e865548000",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/3.17k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eeba1a36b07149d5a4c0314f5367deaa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42093b392eff41ed997e79d4ea1f1bd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ce61494a57946b7ab6d88049451995a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "added_tokens.json:   0%|          | 0.00/293 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed4a42b2fa92450fa0a732402bed30c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/568 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Single source of truth for the checkpoint, so the model and tokenizer cannot drift apart.\n",
    "MODEL_NAME = \"microsoft/Phi-3-mini-128k-instruct\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    device_map=\"cuda\",\n",
    "    torch_dtype=\"auto\",\n",
    "    # NOTE(review): this downloads and runs Python code shipped in the model repo;\n",
    "    # consider pinning `revision=` so the executed code cannot change between runs.\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d63609b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the loaded model and tokenizer in a text-generation pipeline.\n",
    "pipe = pipeline(task=\"text-generation\", model=model, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "da66a082",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(query, search_results):\n",
    "    prompt_template = \"\"\"\n",
    "You're a course teaching assistant. Answer the QUESTION based on the CONTEXT from the FAQ database.\n",
    "Use only the facts from the CONTEXT when answering the QUESTION.\n",
    "\n",
    "QUESTION: {question}\n",
    "\n",
    "CONTEXT:\n",
    "{context}\n",
    "\"\"\".strip()\n",
    "\n",
    "    context = \"\"\n",
    "    \n",
    "    for doc in search_results:\n",
    "        context = context + f\"section: {doc['section']}\\nquestion: {doc['question']}\\nanswer: {doc['text']}\\n\\n\"\n",
    "    \n",
    "    prompt = prompt_template.format(question=query, context=context).strip()\n",
    "    return prompt\n",
    "\n",
    "def llm(prompt):\n",
    "    messages = [\n",
    "        {\"role\": \"user\", \"content\": prompt},\n",
    "    ]\n",
    "\n",
    "    generation_args = {\n",
    "        \"max_new_tokens\": 500,\n",
    "        \"return_full_text\": False,\n",
    "        \"temperature\": 0.0,\n",
    "        \"do_sample\": False,\n",
    "    }\n",
    "\n",
    "    output = pipe(messages, **generation_args)\n",
    "    return output[0]['generated_text'].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "41b774d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Yes, you can still join the course even if you discover it after the start date.'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the whole pipeline on a sample question; the returned answer is displayed as the cell output.\n",
    "rag(\"I just discovered the course. Can I still join it?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142cb2eb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "saturn (Python 3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
